{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "# 下载mnist 训练集\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "data_path = '../data-unversioned/mnist/'\n",
    "mnist = datasets.MNIST(\n",
    "    data_path, train=True, download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.Resize(28),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4915),\n",
    "                             (0.2470))\n",
    "    ]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# 下载mnist验证集\n",
    "mnist_val = datasets.MNIST(\n",
    "    data_path, train=False, download=True,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize((0.4915),\n",
    "                             (0.2470))\n",
    "    ])\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD4CAYAAAAq5pAIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAN80lEQVR4nO3df6hcdXrH8c+ncf3DrBpTMYasNhuRWBWbLRqLSl2RrD9QNOqWDVgsBrN/GHChhEr6xyolEuqP0qAsuYu6sWyzLqgYZVkVo6ZFCF5j1JjU1YrdjV6SSozG+KtJnv5xT+Su3vnOzcyZOZP7vF9wmZnzzJnzcLife87Md879OiIEYPL7k6YbANAfhB1IgrADSRB2IAnCDiRxRD83ZpuP/oEeiwiPt7yrI7vtS22/aftt27d281oAesudjrPbniLpd5IWSNou6SVJiyJia2EdjuxAj/XiyD5f0tsR8U5EfCnpV5Ku6uL1APRQN2GfJekPYx5vr5b9EdtLbA/bHu5iWwC61M0HdOOdKnzjND0ihiQNSZzGA03q5si+XdJJYx5/R9L73bUDoFe6CftLkk61/V3bR0r6kaR19bQFoG4dn8ZHxD7bSyU9JWmKpAci4o3aOgNQq46H3jraGO/ZgZ7ryZdqABw+CDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUii4ymbcXiYMmVKsX7sscf2dPtLly5tWTvqqKOK686dO7dYv/nmm4v1u+66q2Vt0aJFxXU///zzYn3lypXF+u23316sN6GrsNt+V9IeSfsl7YuIs+toCkD96jiyXxQRH9TwOgB6iPfsQBLdhj0kPW37ZdtLxnuC7SW2h20Pd7ktAF3o9jT+/Ih43/YJkp6x/V8RsWHsEyJiSNKQJNmOLrcHoENdHdkj4v3qdqekxyTNr6MpAPXrOOy2p9o++uB9ST+QtKWuxgDUq5vT+BmSHrN98HX+PSJ+W0tXk8zJJ59crB955JHF+nnnnVesX3DBBS1r06ZNK6577bXXFutN2r59e7G+atWqYn3hwoUta3v27Cmu++qrrxbrL7zwQrE+iDoOe0S8I+kvauwFQA8x9AYkQdiBJAg7kARhB5Ig7EASjujfl9om6zfo5s2bV6yvX7++WO/1ZaaD6sCBA8X6jTfeWKx/8sknHW97ZGSkWP/www+L9TfffLPjbfdaRHi85RzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtlrMH369GJ948aNxfqcOXPqbKdW7XrfvXt3sX7RRRe1rH355ZfFdbN+/6BbjLMDyRF2IAnCDiRB2IEkCDuQBGEHkiDsQBJM2VyDXbt2FevLli0r1q+44opi/ZVXXinW2/1L5ZLNmzcX6wsWLCjW9+7dW6yfccYZLWu33HJLcV3UiyM7kARhB5Ig7EAShB1IgrADSRB2IAnCDiTB9ewD4JhjjinW200vvHr16pa1xYsXF9e9/vrri/W1a9cW6xg8HV/PbvsB2zttbxmzbLrtZ2y/Vd0eV2ezAOo3kdP4X0i69GvLbpX0bEScKunZ6jGAAdY27BGxQdLXvw96laQ11f01kq6uty0Adev0u/EzImJEkiJixPYJrZ5oe4mkJR1uB0BNen4hTEQMSRqS+IAOaFKnQ287bM+UpOp2Z30tAeiFTsO+TtIN1f0bJD1eTzsAeqXtabzttZK+L+l429sl/VTSSkm/tr1Y0u8l/bCXTU52H3/8cVfrf/TRRx2ve9NNNxXrDz/8cLHebo51DI62YY+IRS1KF9fcC4Ae4uuyQBKEHUiCsANJEHYgCcIOJMElrpPA1KlTW9aeeOKJ4roXXnhhsX7ZZZcV608//XSxjv5jymYgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSIJx9knulFNOKdY3bdpUrO/evbtYf+6554r14eHhlrX77ruvuG4/fzcnE8bZgeQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtmTW7hwYbH+4IMPFutHH310x9tevnx5sf7QQw8V6yMjIx1vezJjnB1IjrADSRB2IAnCDiRB2IEkCDuQBGEHkmCcHUVnnnlmsX7PPfcU6xdf3Plkv6tXry7WV6xYUay/9957HW/7cNbxOLvtB2zvtL1lzLLbbL9ne3P1c3mdzQKo30RO438h6dJxlv9LRMyrfn5Tb1sA6tY27BGxQdKuPvQCoIe6+YBuqe3XqtP841o9yfYS28O2W/8zMgA912nYfybpFEnzJI1IurvVEyNiKCLOjoizO9wWgBp0FPaI2BER+yPigKSfS5pfb1sA6tZR2G3PHPNwoaQtrZ4LYDC0HWe3vVbS9yUdL2mHpJ9Wj+dJCknvSvpxRLS9uJhx9sln2rRpxfqVV17ZstbuWnl73OHir6xfv75YX7BgQbE+WbUaZz9iAisuGmfx/V13BKCv+LoskARhB5Ig7EAShB1IgrADSXCJKxrzxRdfFOtHHFEeLNq3b1+xfskll7SsPf/888V1D2f8K2kgOcIOJEHYgSQIO5AEYQeSIOxAEoQdSKLtVW/I7ayzzirWr7vuumL9nHPOaVlrN47eztatW4v1DRs2dPX6kw1HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2SW7u3LnF+tKlS4v1a665plg/8cQTD7mnidq/f3+xPjJS/u/lBw4cqLOdwx5HdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgnH2w0C7sexFi8abaHdUu3H02bNnd9JSLYaHh4v1FStWFOvr1q2rs51Jr+2R3fZJtp+zvc32G7ZvqZZPt/2M7beq2+N63y6ATk3kNH6fpL+PiD+X9FeSbrZ9uqRbJT0bEadKerZ6DGBAtQ17RIxExKbq/h5J2yTNknSVpDXV09ZIurpHPQKowSG9Z7c9W9L3JG2UNCMiRqTRPwi2T2ixzhJJS7rsE0CXJhx229+W9Iikn0TEx/a4c8d9Q0QMSRqqXoOJHYGGTGjozfa3NBr0X0bEo9XiHbZnVvWZknb2pkUAdWh7ZPfoIfx+Sdsi4p4xpXWSbpC0srp9vCcdTgIzZswo1k8//fRi/d577y3WTzvttEPuqS4bN24s1u+8886WtccfL//KcIlqvSZyGn++pL+V9LrtzdWy5RoN+a9tL5b0e0k/7EmHAGrRNuwR8Z+SWr1Bv7jedgD0Cl+XBZIg7EAShB1IgrADSRB2IAkucZ2g6dOnt6ytXr26uO68efOK9Tlz5nTSUi1efPHFYv3uu+8u1p966qli/bPPPjvkntAbHNmBJAg7kARhB5Ig7EAShB1IgrADSRB2IIk04+znnntusb5s2bJiff78+S1rs2bN6qinunz66acta6tWrSque8cddxTre/fu7agnDB6O7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRJpx9oULF3ZV78bWrVuL9SeffLJY37dvX7FeuuZ89+7dxXWRB0d2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUjCEVF+gn2SpIcknSjpgKShiPhX27dJuknS/1ZPXR4Rv2nzWuWNAehaRIw76/JEwj5T0syI2GT7aEkvS7pa0t9I+iQi7ppoE4Qd6L1WYZ/I/Owjkkaq+3tsb5PU7L9mAXDIDuk9u+3Zkr4naWO1aKnt12w/YPu4FusssT1se7i7VgF0o+1p/FdPtL8t6QVJKyLiUdszJH0gKST9k0ZP9W9s8xqcxgM91vF7dkmy/S1JT0p6KiLuGac+W9KTEXFmm9ch7ECPtQp729N425Z0v6RtY4NefXB30EJJW7ptEkDvTOTT+Ask/Yek1zU69CZJyyUtkjRPo6fx70r6cfVhXum1OLIDPdbVaXxdCDvQex2fxgOYHAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ9HvK5g8k/c+Yx8dXywbRoPY2qH1J9NapOnv7s1aFvl7P/o2N28MRcXZjDRQMam+D2pdEb53qV2+cxgNJEHYgiabDPtTw9ksGtbdB7Uuit071pbdG37MD6J+mj+wA+oSwA0k0Enbbl9p+0/bbtm9toodWbL9r+3Xbm5uen66aQ2+n7S1jlk23/Yztt6rbcefYa6i322y/V+27zbYvb6i3k2w/Z3ub7Tds31Itb3TfFfrqy37r+3t221Mk/U7SAknbJb0kaVFEbO1rIy3YflfS2RHR+BcwbP+1pE8kPXRwai3b/yxpV0SsrP5QHhcR/zAgvd2mQ5zGu0e9tZpm/O/U4L6rc/rzTjRxZJ8v6e2IeCcivpT0K0lXNdDHwIuIDZJ2fW3xVZLWVPfXaPSXpe9a9DYQImIkIjZV9/dIOjjNeKP7rtBXXzQR9lmS/jDm8XYN1nzvIelp2y/bXtJ0M+OYcXCarer2hIb7+bq203j309emGR+YfdfJ9OfdaiLs401NM0jjf+dHxF9KukzSzdXpKibmZ5JO0egcgCOS7m6ymWqa8Uck/SQiPm6yl7HG6asv+62JsG+XdNKYx9+R9H4DfYwrIt6vbndKekyjbzsGyY6DM+hWtzsb7ucrEbEjIvZHxAFJP1eD+66aZvwRSb+MiEerxY3vu/H66td+ayLsL0k61fZ3bR8p6UeS1jXQxzfYnlp9cCLbUyX9QIM3FfU6STdU92+Q9HiDvfyRQZnGu9U042p43zU+/XlE9P1H0uUa/UT+vyX9YxM9tOhrjqRXq583mu5N0lqNntb9n0bPiBZL+lNJz0p6q7qdPkC9/ZtGp/Z+TaPBmtlQbxdo9K3ha5I2Vz+XN73vCn31Zb/xdVkgCb5BByRB2IEkCDuQBGEHkiDsQBKEHUiCsANJ/D+f1mbt6t55/AAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 选择一张图片看看\n",
    "img, _ = mnist[0]\n",
    "img.shape\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "plt.imshow(img.mean(0), cmap='gray')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The autoreload extension is already loaded. To reload it, use:\n",
      "  %reload_ext autoreload\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import net\n",
    "%load_ext autoreload\n",
    "%autoreload 1\n",
    "%aimport net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "model = net.Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "tensor([[-2.3702, -2.3702, -2.3702, -2.1666, -2.3702, -2.3702, -2.3417, -2.3702,\n         -2.0021, -2.3702]], grad_fn=<LogSoftmaxBackward>)"
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data, target = mnist_val[0]\n",
    "model(data.unsqueeze(0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "device = (torch.device('cuda') if torch.cuda.is_available()\n",
    "          else torch.device(\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import datetime\n",
    "\n",
    "\n",
    "def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):\n",
    "    for epoch in range(1, n_epochs + 1):\n",
    "        loss_train = 0.0\n",
    "        for batch_idx, (data, target) in enumerate(train_loader):\n",
    "            data = data.to(device=device)\n",
    "            target = target.to(device=device)\n",
    "            optimizer.zero_grad()\n",
    "            output = model(data)\n",
    "            loss = loss_fn(output, target)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            loss_train += loss.item()\n",
    "        if epoch == 1 or epoch % 10 == 0:\n",
    "            print('{} Epoch {}, Training loss {}'.format(\n",
    "                datetime.datetime.now(), epoch,\n",
    "                loss_train / len(train_loader)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n",
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "optimizer = optim.SGD(model.parameters(), lr=1e-2,weight_decay = 0.0001, momentum=0.5)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "train_loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)\n",
    "training_loop(3, optimizer, model, loss_fn, train_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy train: 0.88\n",
      "Accuracy val: 0.88\n"
     ]
    }
   ],
   "source": [
    "import collections\n",
    "from util import validate\n",
    "\n",
    "all_acc_dict = collections.OrderedDict()\n",
    "val_loader = torch.utils.data.DataLoader(mnist_val, 64, shuffle=False)\n",
    "all_acc_dict['resnet'] = validate(model, train_loader, val_loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "OrderedDict([('resnet', {'train': 0.8825833333333334, 'val': 0.8847})])"
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_acc_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "<Figure size 432x288 with 1 Axes>",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAY0AAAEgCAYAAABIJS/hAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaZUlEQVR4nO3de5RdZZnn8e9DURBykYQQYkwBicgQI0LUMjDiKHjBgCJex6BL7YimaaGbQXSIOra0Lcpq1tjKiIbYk6FVIqKSJmoURLm1wpAKBkMikRCCKSLmgtwc0yTwzB97B4+VqmSf1KlzDsn3s1atOnu/797nCYvwY+/33fuNzESSpCr2aXUBkqRnD0NDklSZoSFJqszQkCRVZmhIkiozNCRJlTU1NCJifkRsiIi7B2iPiLg0IlZHxK8i4qU1bTMiYlXZNqd5VUuStmv2lcYVwIydtJ8CHFn+zAa+ChARHcBlZftU4IyImDqklUqSdtDU0MjMW4CHd9LldODrWbgdGB0RE4DpwOrMXJOZTwJXlX0lSU3UbmMaE4F1Ndu95b6B9kuSmmjfVhfQR/SzL3eyv/+TRMymuL3FiBEjXjZlypTGVCdJe4GlS5duysxx/bW1W2j0AofWbHcB64H9Btjfr8ycB8wD6O7uzp6ensZXKkl7qIh4YKC2drs9tQh4XzmL6njg0cz8HbAEODIiJkfEfsDMsq8kqYmaeqUREd8CTgQOjohe4NNAJ0BmzgUWA6cCq4H/B8wq27ZFxDnAdUAHMD8zVzSzdklSk0MjM8/YRXsCZw/QtpgiVCRJLdJut6ckSW3M0JAkVWZoSJIqMzQkSZUZGpKkygwNSVJlhoYkqTJDQ5JUmaEhSarM0JAkVWZoSJIqMzQkSZUZGpKkygwNSVJlhoYkqTJDQ5JUmaEhSarM0JAkVWZoSJIqMzQkSZUZGpKkygwNSVJlhoYkqTJDQ5JUmaEhSarM0JAkVWZoSJIqa3poRMSMiFgVEasjYk4/7WMiYmFE/Coi7oiIo2va1kbE8ohYFhE9za1ckrRvM78sIjqAy4DXA73AkohYlJkra7p9AliWmW+NiCll/9fWtJ+UmZuaVrQk6RnNvtKYDqzOzDWZ+SRwFXB6nz5TgZ8CZOY9wKSIGN/cMiVJ/Wl2aEwE1tVs95b7at0FvA0gIqYDhwNdZVsC10fE0oiYPcS1SpL6aOrtKSD62Zd9ti8GvhQRy4DlwC+BbWXbCZm5PiIOAX4SEfdk5i07fEkRKLMBDjvssEbVLkl7vWZfafQCh9ZsdwHraztk5mOZOSszpwHvA8YB95dt68vfG4CFFLe7dpCZ8zKzOzO7x40b1/A/hCTtrZodGkuAIyNickTsB8wEFtV2iIjRZRvAB4FbMvOxiBgREaPKPiOAk4G7m1i7JO31mnp7KjO3RcQ5wHVABzA/M1dExFll+1zghcDXI+IpYCVwZnn4eGBhRGyve0Fm/riZ9UvS3i4y+w4p7Fm6u7uzp8dHOiSpqohYmpnd/bX5RLgkqTJDQ5JUmaEhSarM0JAkVWZoSJIqMzQkSZUZGpKkygwNSVJlhoYkqTJDQ5JUmaEhSarM0JAkVWZoSJIqMzQkSZUZGpKkygwNSVJlhoYkqTJDQ5JUmaEhSarM0JAkVWZoSJIqMzQkSZUZGpKkygwNSVJlhoYkqTJDQ5JUmaEhSarM0JAkVdb00IiIGRGxKiJWR8ScftrHRMTCiPhVRNwREUdXPVaSNLSaGhoR0QFcBpwCTAXOiIipfbp9AliWmccA7wO+VMexkqQh1OwrjenA6sxck5lPAlcBp/fpMxX4KUBm3gNMiojxFY+VJA2hZofGRGBdzXZvua/WXcDbACJiOnA40FXxWMrjZkdET0T0bNy4sUGlS5KaHRrRz77ss30xMCYilgF/C/wS2Fbx2GJn5rzM7M7M7nHjxg2iXElSrX2b/H29wKE1213A+toOmfkYMAsgIgK4v/wZvqtjJUlDq/KVRkT8PCLeGxH7D+L7lgBHRsTkiNgPmAks6vM9o8s2gA8Ct5RBsstjJUlDq57bU1uBfwXWR8QXImJKvV+WmduAc4DrgF8DV2fmiog4KyLOKru9EFgREfdQzJQ6d2fH1luDJGn3RWa/wwL9d444CvhriqmwY4Bbga8C12Tm1iGpcJC6u7uzp6en1WVI0rNGRCzNzO7+2uoaCM/MVZn5EYpZS38FdAALgN6IuDginj/YYiVJ7Wu3Zk9l5n9k5jcobh3dCowD/jvwm4j4TkQ8t4E1SpLaRN2hEREHRMQHIuIOisHpcRTh8Tzgb4BXAFc2tEpJUluoPOU2Il5MMZ7xHmAEcC1wQWbeWNPtaxHxEPCdhlYpSWoL9TyncRfFcxFfBOZl5u8G6LcauG2QdUmS2lA9ofFO4N8y86mddcrMXwMnDaoqaW9w4YGtrqB5Lny01RWoQeoZ01gEDOuvISJGRERnY0qSJLWreq40/gXoBN7dT9vlwJPABxpRlPZek+b8sNUlNM3afv8XTGpv9VxpnEQx+N2fRcBrB1+OJKmd1RMahwAbBmjbCIwffDmSpHZWT2hsAF48QNuLgc2DL0eS1M7qCY0fAJ+KiGNqd5bPb3wS+H4jC5MktZ96BsL/Hng9sDQilvDnlfOmU6x38T8aX54kqZ1UvtLIzE3Ay4HPU6yiN638fRHw8rJdkrQHq2vlvsx8hOKK4++HpBpJUltr9hrhkqRnsbquNCLiaOBM4Ch2fDo8M9NnNSRpD1bPW26PA24G1gJHAr+iWL3vMIpB8dVDUJ8kqY3Uc3vqc8A1wIsoBsDPzMxJwOsoVvD7bMOrkyS1lXpC4xjgm8D2RcU7ADLzZxSB8fnGliZJajf1hEYn8MfMfBp4GJhQ07YKOLqRhUmS2k89oXEfxcN8UIxnfCAi9omIfYBZwEONLk6S1F7qmT31A+BEYAHF+MYPgceAp4CRwN81ujhJUnupHBqZ+emazzdExPHA24HhwI8z8/ohqE+S1EYqhUa5Kt+pwK8y836AzPwl8MshrE2S1GYqjWlk5lbgamDSkFYjSWpr9QyEr6FYiEmStJeqJzT+CfhkRIwbzBdGxIyIWBURqyNiTj/tB0bE9yPirohYERGzatrWRsTyiFgWET2DqUOSVL96Zk+9BjgIuD8ibgd+x58f9IPi3VPv39kJIqIDuIxiXY5eYElELMrMlTXdzgZWZuZpZUCtiogrM/PJsv0kX8MuSa1RT2i8EthKsR74EeVPrdzhiB1NB1Zn5hqAiLgKOB2oDY0ERkVEUEzlfRjYVkedkqQhUs+U28kN+L6JwLqa7V7guD59vgwsAtYDo4B3lU+hQxEo10dEApdn5rwG1CRJqqjZ62lEP/v6XqG8AVgGPI9idcAvR8RzyrYTMvOlwCnA2RHxqn6/JGJ2RPRERM/GjRsbUrgkqY7QiIjDdvVT4TS9wKE1210UVxS1ZgHXZGE1xfrjUwAyc335ewOwkOJ21w4yc15mdmdm97hxgxq3lyTVqGdMYy27Hrfo2EX7EuDIiJgMPAjMBN7dp89vgdcCt0bEeIoFn9ZExAhgn8x8vPx8MvCZOuqXJA1SPaHxAXYMjbHAG4HnA/+4qxNk5raIOAe4jiJg5mfmiog4q2yfW57niohYTnE764LM3BQRzwcWFuPj7AssyMwf11G/JGmQ6hkIv2KApi9ExDcogqPKeRYDi/vsm1vzeT3FVUTf49YAx1atV5LUeI0aCP8mxZWIJGkP1qjQOAQY1qBzSZLaVOXbUwNMb92PYsW+jwO3NqooSVJ7qmcg/CZ2HAjf/tzFzcDfNKIgSVL7qic0Tupn3xbggcx0qVdJ2gvUM3vq5qEsRJLU/up5Ivz4iPivA7S9MyL6vkNKkrSHqWf21OeBFw3Q9sKyXZK0B6snNI4Fbh+g7Q7gmMGXI0lqZ/WExrCd9O8ARgy+HElSO6snNH4NvHmAtjcDqwZfjiSpndUz5XYucHlEPAZ8jeI15xOB2cCZwIcbX54kqZ3UM+X2axFxFHAe8JHaJuCfXUVPkvZ89VxpkJkfjYivAq+jeC36JuCG7Wt+S5L2bHWFBkBm3gfcNwS1SJLaXD0P982KiAsHaLswIt7fsKokSW2pntlT5wKbB2jbAPy3QVcjSWpr9YTGC4AVA7T9Gjhi8OVIktpZPaGxDTh4gLZxDahFktTm6gmNO4CzBmg7C1gy+HIkSe2sntlTFwE3RMT/Bf4FeJDi4b4PAi8FXt/48iRJ7aSu9TQi4h3AF4HLa5rWAm/PzJsaWpkkqe3U+3DftcC15ZPhY4FNmfmbIalMktR26n64DyAzfTmhJO2F6g6NiDgWOIriVel/ITO/3oiiJEntqXJoRMRo4IfA8dt3lb+zppuhIUl7sHqm3H6OYhzjVRSB8VbgNcCVwBpgesOrkyS1lXpC4w0UwbF9ydfezLwpM98H3EDxmpFdiogZEbEqIlZHxJx+2g+MiO9HxF0RsSIiZlU9VpI0tOoJjQnAmsx8CtgCjKppuwZ4465OEBEdwGXAKcBU4IyImNqn29nAysw8FjgR+J8RsV/FYyVJQ6ie0HgIGF1+fgD4zzVtL6h4junA6sxck5lPAlcBp/fpk8CoiAhgJPAwxStMqhwrSRpC9cye+neKoPgB8A3g0xExieI/6O8HFlU4x0RgXc12L3Bcnz5fLs+1nuJq5l2Z+XREVDlWkjSE6gmNfwCeV36+hGJQ/F3AcIr/yP9thXNEP/uyz/YbgGUUg+xHAD+JiFsrHlt8ScRsirXLOeywwyqUJUmqovLtqcy8LzNvLT9vzczzM7MrMw/KzHdn5kBrbdTqBQ6t2e6iuKKoNQu4JgurgfuBKRWP3V7rvMzszszuceN8Aa8kNUo9YxqNsAQ4MiImR8R+wEx2vK31W+C1ABExnuJBwjUVj5UkDaHdeo3I7srMbRFxDnAd0AHMz8wVEXFW2T4X+EfgiohYTnFL6oLM3ATQ37HNrF+S9nZNDQ2AzFwMLO6zb27N5/XAyVWPlSQ1T7NvT0mSnsUMDUlSZYaGJKkyQ0OSVJmhIUmqzNCQJFVmaEiSKjM0JEmVGRqSpMoMDUlSZYaGJKkyQ0OSVJmhIUmqzNCQJFVmaEiSKjM0JEmVGRqSpMoMDUlSZYaGJKkyQ0OSVJmhIUmqzNCQJFVmaEiSKjM0JEmVGRqSpMoMDUlSZYaGJKmypodGRMyIiFURsToi5vTT/rGIWFb+3B0RT0XEQWXb2ohYXrb1NLt2Sdrb7dvML4uIDuAy4PVAL7AkIhZl5srtfTLzEuCSsv9pwHmZ+XDNaU7KzE1NLFuSVGpqaADTgdWZuQYgIq4CTgdWDtD/DOBbTapNkgDYunUrvb29bNmypdWlDKlhw4bR1dVFZ2dn5WOaHRoTgXU1273Acf11jIjhwAzgnJrdCVwfEQlcnpnzhqpQSXuv3t5eRo0axaRJk4iIVpczJDKTzZs309vby+TJkysf1+wxjf7+6ecAfU8Dft7n1tQJmflS4BTg7Ih4Vb9fEjE7Inoiomfjxo2Dq1jSXmfLli2MHTt2jw0MgIhg7NixdV9NNTs0eoFDa7a7gPUD9J1Jn1tTmbm+/L0BWEhxu2sHmTkvM7szs3vcuHGDLlrS3mdPDoztdufP2OzQWAIcGRGTI2I/imBY1LdTRBwIvBq4tmbfiIgYtf0zcDJwd1OqliQBTR7TyMxtEXEOcB3QAczPzBURcVbZPrfs+lbg+sz8Y83h44GFZTLuCyzIzB83r3pJe6tJc37Y0POtvfiNO21/5JFHWLBgAR/+8IfrOu+pp57KggULGD169CCq27lmD4STmYuBxX32ze2zfQVwRZ99a4Bjh7g8SWq5Rx55hK985Ss7hMZTTz1FR0fHgMctXrx4wLZGaXpoSJJ2bs6cOdx3331MmzaNzs5ORo4cyYQJE1i2bBkrV67kLW95C+vWrWPLli2ce+65zJ49G4BJkybR09PDE088wSmnnMIrX/lKfvGLXzBx4kSuvfZaDjjggEHX5mtEJKnNXHzxxRxxxBEsW7aMSy65hDvuuIOLLrqIlSuLR9rmz5/P0qVL6enp4dJLL2Xz5s07nOPee+/l7LPPZsWKFYwePZrvfe97DanNKw1JanPTp0//i2cpLr30UhYuXAjAunXruPfeexk7duxfHDN58mSmTZsGwMte9jLWrl3bkFoMDUlqcyNGjHjm80033cQNN9zAbbfdxvDhwznxxBP7fdZi//33f+ZzR0cHf/rTnxpSi7enJKnNjBo1iscff7zftkcffZQxY8YwfPhw7rnnHm6//fam1uaVhiTtwq6myDba2LFjOeGEEzj66KM54IADGD9+/DNtM2bMYO7cuRxzzDEcddRRHH/88U2tzdCQpDa0YMGCfvfvv//+/OhHP+q3bfu4xcEHH8zdd//52eePfvSjDavL21OSpMoMDUlSZYaGJKkyQ0OSVJmhIUmqzNCQJFXmlFtJ2pULD2zw+R5t6OlGjhzJE0880dBzDsQrDUlSZV5pSFKbueCCCzj88MOfWU/jwgsvJCK45ZZb+MMf/sDWrVv57Gc/y+mnn9702rzSkKQ2M3PmTL797W8/s3311Vcza9YsFi5cyJ133smNN97I+eefT2Y2vTavNCSpzbzkJS9hw4YNrF+/no0bNzJmzBgmTJjAeeedxy233MI+++zDgw8+yO9//3ue+9znNrU2Q0OS2tA73vEOvvvd7/LQQw8xc+ZMrrzySjZu3MjSpUvp7Oxk0qRJ/b4SfagZGpLUhmbOnMmHPvQhNm3axM0338zVV1/NIYccQmdnJzfeeCMPPPBAS+oyNCRpVxo8RbaKF73oRTz++ONMnDiRCRMm8J73vIfTTjuN7u5upk2bxpQpU5peExgaktS2li9f/szngw8+mNtuu63ffs16RgOcPSVJqoOhIUmqzNCQpH604hmIZtudP6OhIUl9DBs2jM2bN+/RwZGZbN68mWHDhtV1nAPhktRHV1cXvb29bNy4sdWlDKlhw4bR1dVV1zGGhiT10dnZyeTJk1tdRltq+u2piJgREasiYnVEzOmn/WMRsaz8uTsinoqIg6ocK0kaWk0NjYjoAC4DTgGmAmdExNTaPpl5SWZOy8xpwMeBmzPz4SrHSpKGVrOvNKYDqzNzTWY+CVwF7OzdvmcA39rNYyVJDdbsMY2JwLqa7V7guP46RsRwYAZwzm4cOxuYXW4+ERGrBlGzNCQCDgY2tbqOpviHaHUFqs/hAzU0OzT6+zdnoDltpwE/z8yH6z02M+cB8+ovT2qeiOjJzO5W1yHVo9m3p3qBQ2u2u4D1A/SdyZ9vTdV7rCRpCDQ7NJYAR0bE5IjYjyIYFvXtFBEHAq8Grq33WEnS0Gnq7anM3BYR5wDXAR3A/MxcERFnle1zy65vBa7PzD/u6thm1i81mLdQ9awTe/Jj8pKkxvLdU5KkygwNSVJlhoYkqTJDQ5JUmaEhSarM0JDaSER0lr87Wl2L1B9DQ2oTETEsM7dGxCjgxxExptU1SX0ZGlL7+HRETAeuADoz8w9ecajduHKf1AYi4hUUb0J4A/Bi4G0AmflURATFg7hPt7BECfCJcKltRMQE4BfAqPL3T4FrMnNdTZ+TgZ+kf3HVIoaG1AbK21D7A+cDdwJ/BfwnYCVwDcXLOd8OfDMzva2sljE0pDYVEe8FZgEjgKeB5wNzMvP/tLQw7dUMDamFIqKjHLfoBl4FnAT0AFdk5gMRcRDwfuBA4LeZOb+F5UqGhtQqEbFPZj4dEc+lGL/4PfAz4DPAPwEf3z52sT1cWletVPDeqNQ62/+P7SvAysx8DfAd4DHgqszMiHhjRIw0MNQuDA2pRcpQOIRiGeNLy93fBL6emcsiYjjwTuDsctqt1HKGhtRamymuOLoi4k3A84DPlW0BTAU2OMVW7cKH+6QWKgfBF1Ksef8K4JLMfKi8svhrYIKzpdRODA2p9a4FTgSGA5Mj4gKK6bUnA+e2sC5pB86ekpqoZsZUJ9CVmfeX+4cBc4DXASOBB4D/nZmLWlettCNDQ2qSmmcyuoDLgJcAHcCXgf+VmU+Ub7Z9AiAzt7auWql/hobUZBFxA8Ug93eAQ4CPAJuATwNXGxZqZ86ekpogIvYpf48DhgFnZObczPwMxQypm4H5wE8j4uWtq1TaOUNDaoKa15q/BVgLHAHP3LJan5lnAq8EDgKmtaBEqRJnT0lNEhHTgLkUt6buBW4rxzj2oXjWbwlwdAtLlHbJKw2pSTJzGXAs8CXgkxFxY0Qcm5lPl0+Hu0qf2p6hIQ2hmrGM50fEuMy8G/gExTMYAHdGxFciYqzvl9KzgaEhDZFyvOLpiJhCse73f4mIYZn5p8y8iWJp1zOBNwFrI2Jk66qVqnHKrTTEIuLfgTXA+Zm5MSLGUqzK91Bm3h8RhwEvyMyftbRQqQJDQxoCERHlOMUbgH8FJmfmnyLiOIqH+Z4DPAnMzszbWlmrVA9vT0lDoOattBMo1vkmIk4DPg78Bvg7YBvwsrLNV5/rWcHQkIbWrymCYT5wNbCUYp3v64BlwBT4i5CR2pq3p6QhFhFvAU4D7szMy8p9R1CExqmZeWvrqpPqY2hIDbR9LKP8fDSwKTMf6tPnTRRrZWzJzHe2oExpt3l7SmqQcoptRsSUiLiKYp2MtRFxfURMjUIHxVoZ/wGc1dKCpd3glYbUYBFxB7AK+BRwBvAxYEpmbqjp85zMfKxFJUq7zXdPSQ0UEacDXcAJmbk1It4HfC4zN0TEK4BTgH/OzIdbWqi0m7w9JTXWY8CSMjA+RTGt9stl25PADGBiq4qTBssrDWmQalbkOwsYC7wqIt4MnAfMzMwtZdf3Ao9n5vJW1SoNlqEhDUI5W+qpiBgPfAF4NTAa+CrwILC8fHDvXcAHgNe0qlapEQwNaRBqHso7HvgexcN7CYynWEzpZxQLKz0AfKFcM0N61nL2lLSbat4vdRRwIcVU2tdk5h/Lq4u3U1x1jKIIlN6aFfykZyVDQxqk8qWElwPPAxYAXywXXOrbL9K/cHqWMzSkBoiIERQD32dQLOX6I+DfMvP3LS1MajBDQ2qgiDgc+DxwFHAXsBj4nlcY2lP4nIbUQJn5QGa+m+LV59OB4w0M7Um80pCGSER0Agf4uhDtSQwNSVJl3p6SJFVmaEiSKjM0JEmVGRqSpMoMDUlSZYaGJKmy/w+6ZvN1Ckd4nwAAAABJRU5ErkJggg==\n"
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "import numpy as np\n",
    "from matplotlib import  pyplot as plt\n",
    "trn_acc = [v['train'] for k, v in all_acc_dict.items()]\n",
    "val_acc = [v['val'] for k, v in all_acc_dict.items()]\n",
    "\n",
    "width = 0.3\n",
    "plt.bar(np.arange(len(trn_acc)), trn_acc, width = width,label = 'train')\n",
    "plt.bar(np.arange(len(val_acc))+width, val_acc, width = width,label = 'val')\n",
    "plt.xticks(np.arange(len(val_acc))+width/2,list(all_acc_dict.keys()), rotation = 60,fontsize = 14)\n",
    "plt.ylabel('accuracy',fontsize = 16)\n",
    "plt.legend(loc = 'lower right')\n",
    "plt.ylim(0.7,1)\n",
    "plt.xlim(-0.7,1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": "(60000, 10000)"
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(mnist), len(mnist_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch8cpu",
   "language": "python",
   "name": "torch8cpu"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}